Add Mooncake integration #1342
Closing temporarily to iterate on my fork without wasting CI here.
@ChrisRackauckas can you approve the workflow?
@static if isdefined(Mooncake, :FriendlyTangentCache) # checks Mooncake >= v0.5.25
    # see https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/998
    function Mooncake.friendly_tangent_cache(x::Union{SArray,MArray})
Does this actually work with SArray where elements are Symmetric SArrays?
What exactly is Mooncake doing to the returned value to get a mutable type? copy on SArray returns an immutable again.
> Does this actually work with SArray where elements are Symmetric SArrays?
Probably not; one might need some kind of deep copy. Or we could restrict it to Union{SArray{<:Tuple,<:IEEEFloat}, MArray{<:Tuple,<:IEEEFloat}} to be safe, albeit incomplete.
> What exactly is Mooncake doing to the returned value to get a mutable type? copy on SArray returns an immutable again.
Contrary to what the name "cache" suggests, I don't think the output of friendly_tangent_cache has to be mutable. From what I understand reading Claude's comments in chalk-lab/Mooncake.jl#1103, it seems that this function is meant to output a kind of template for reconstructing the gradient type we want.
To clarify, I had nothing to do with the breaking changes in Mooncake v0.5.25. I actually disapprove of them and I don't fully understand them (especially since they were LLM-generated). I'm just trying to keep the bare minimum working (Mooncake's gradient returning an SArray in simple cases).
I don't know enough about Mooncake to have an opinion yet. I just think a few more complex examples regarding how this is all supposed to work for nested types would be really helpful.
I did take a look at the Julia AD ecosystem a couple of years ago so I know a bit about how it works, and I still use it sometimes, though I don't see good areas to contribute to there sadly. Too many conflicting goals and points of view on how things should work.
I think @yebai is the expert on this new Mooncake API. Perhaps it would be better as a Mooncake extension, since it is a very modest amount of code and StaticArrays is a very old and stable package?
StaticArrays.jl already has a ChainRulesCore extension, so adding Mooncake wouldn't be unprecedented. My main worry is that this friendly tangent thing doesn't seem particularly stable and well-tested, even in comparison with ChainRules, so it's unclear how many iterations on the idea are still needed. I'd suggest figuring out more complex examples outside of StaticArrays.jl first, and making an extension once it's clear that the API works well across multiple nested array types from different packages.
Closing this PR for now since it is unclear to me that augmenting StaticArrays this way is the right design choice. This code should live in Mooncake.
Mooncake is a recently developed autodiff library for Julia. Since it released v0.5.25 (more specifically since this PR), Mooncake needs specific overloads to return a StaticArray when taking the gradient of f(x::StaticArray). This PR adds the necessary machinery without committing type piracy. It would be better suited to Mooncake itself, but the maintainer refuses to consider support even for standard libraries and base types, so I don't think StaticArrays.jl stands a chance.
See also: